
TRN2 Meshes and Configurations #916

Open · wants to merge 3 commits into main from mainline-upstream-boilerplate
Conversation

apoorvtintin (Contributor)

This PR adds meshes for TRN2/1 for the Fuji models, plus a transformer layer configuration favorable to Neuron.

Neuron supports a stacked transformer and GroupedQKVLinear (instead of FusedGroupedQKVLinear) for Grouped Query Attention (GQA).

This is a newer version of PR #885; it resolves all comments and requested changes from the linked PR.

@apoorvtintin requested review from ruomingp, markblee and a team as code owners on January 10, 2025 00:48
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 6b404f6 to 3f7c840 on January 10, 2025 00:53
@apoorvtintin (Contributor, Author):

Added a ModelConfigModifier that overrides the class for a module, allowing different model configurations based on model size and platform.
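
As a rough illustration of the idea (a self-contained sketch with toy types; the actual AXLearn ModelConfigModifier API differs), the modifier walks a dotted module path inside a nested config and swaps the target module's class while keeping its other settings:

    # Toy sketch only: SimpleNamespace stands in for real config objects, and
    # find_target_module is a hypothetical helper, not the AXLearn function.
    from types import SimpleNamespace

    def find_target_module(module_name, cfg):
        """Returns (found_module, parent, key_in_parent) for an x.y.z path."""
        keys = module_name.split(".")  # e.g. "model.decoder.layer"
        parent, node = None, cfg
        for key in keys:
            parent, node = node, getattr(node, key)
        return node, parent, keys[-1]

    cfg = SimpleNamespace(
        model=SimpleNamespace(
            decoder=SimpleNamespace(layer=SimpleNamespace(klass="FusedGroupedQKVLinear"))
        )
    )
    found, parent, key = find_target_module("model.decoder.layer", cfg)
    found.klass = "GroupedQKVLinear"  # Swap the class; other fields are kept.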

@kelvin-zou (Contributor) left a comment:

Thank you for making this change; overall it looks good. A few nit comments.

continue
# Here we assume x.y.z format.
# One example would be model.decoder.transformer.layer.
target_modules = module_name.split(".")
Contributor:

Can you try to extract a common util function, named something like def replace_module_recursive(target_modules: str, config_key: str, target_config), and apply it both here and in RematSpecModifier?

@apoorvtintin (Contributor, Author) commented on Jan 10, 2025:

I extracted a helper function; let me know if this looks good.

axlearn/common/trainer_config_modifier_test.py (review thread: outdated, resolved)
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 708fc5e to d481132 on January 10, 2025 07:38
@apoorvtintin (Contributor, Author) commented on Jan 10, 2025:

Added a ParameterPartitionSpecModifier to shard embeddings in a vocab-parallel manner, as described in Megatron-LM.
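
For context, vocab-parallel sharding in the Megatron-LM sense splits the embedding table along the vocabulary dimension rather than the hidden dimension. A minimal sketch with jax.sharding.PartitionSpec (the axis names here are assumptions, not the PR's actual specs):

    from jax.sharding import PartitionSpec

    # An embedding weight has shape [vocab_size, hidden_dim].
    hidden_parallel = PartitionSpec(None, "model")  # shard the hidden dim
    # Vocab-parallel instead shards the vocab dim, so each model-parallel
    # rank owns a contiguous slice of the vocabulary:
    vocab_parallel = PartitionSpec("model", None)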

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 5be50d7 to 9b10041 on January 10, 2025 08:10
axlearn/common/trainer_config_modifier.py (4 review threads: outdated, resolved)

found_module, parent_module, key_in_parent = find_target_module(module_name, cfg)

# Copy configurations from the config being replaced on a best-effort basis.
Contributor:

Wait, this behavior is not explained in the class comments. So we are not replacing but merging the configs? Maybe we should support a merge function instead?

apoorvtintin (Contributor, Author):

Yeah, the goal is to change the config to a similar module, which means most of the configuration can be reused. Essentially we replace the module but merge the config. Let me extract a merge function.

apoorvtintin (Contributor, Author):

Abstracted out a merge function; let me know if more changes are needed.
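
A best-effort merge of this kind might look like the following sketch (hypothetical names and dict-based configs; the actual implementation operates on AXLearn ConfigBase objects):

    def merge_configs(target_cfg: dict, found_module: dict) -> dict:
        """Copies fields of the replaced config into the new one, best effort."""
        merged = dict(target_cfg)
        for key, value in found_module.items():
            # Keep the new klass; carry over any field the target also declares.
            if key != "klass" and key in target_cfg:
                merged[key] = value
        return merged

    old = {"klass": "FusedGroupedQKVLinear", "num_heads": 8}
    new = {"klass": "GroupedQKVLinear", "num_heads": None}
    assert merge_configs(new, old) == {"klass": "GroupedQKVLinear", "num_heads": 8}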

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 9b10041 to 0f0a530 on January 12, 2025 07:06
@apoorvtintin (Contributor, Author):

@ruomingp Thank you for the review. I have addressed all your comments; please let me know if more changes are needed.

@apoorvtintin requested a review from ruomingp on January 12, 2025 07:08
axlearn/common/trainer_config_modifier.py (5 review threads: resolved, 4 outdated)
Comment on lines 239 to 244
for module_name, model_cfg in self._model_cfg_modifications.items():
found_module = _find_target_module(module_name, cfg)
Contributor:

In utils.py we have get_recursively and set_recursively for Nested[...]. I wonder if it would be useful to add corresponding methods to ConfigBase. Then we could do something like:

Suggested change:
-    for module_name, model_cfg in self._model_cfg_modifications.items():
-        found_module = _find_target_module(module_name, cfg)
+    for cfg_path, cfg_modification in self._model_cfg_modifications.items():
+        child_cfg = cfg.get_recursively(cfg_path)
+        child_cfg = cfg_modification(child_cfg, path=cfg_path)
+        cfg.set_recursively(cfg_path, value=child_cfg)

apoorvtintin (Contributor, Author):

Added get_recursively and set_recursively functions to ConfigBase. Let me know if it looks good.

Contributor:

I wonder if an alternative (which aims to simplify the ConfigBase API) is to do something similar to Python's sorted: we allow utils.get_recursively to take a value fn:

    # Default behavior is to use key lookup:
    utils.get_recursively(..., value_fn=lambda k, v: v[k])

    # Custom behavior can be attribute lookup:
    utils.get_recursively(..., value_fn=lambda k, v: getattr(v, k))

A benefit is that other non-config instances can also leverage get_recursively.
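
A sketch of what that generalization could look like (assumed signature, not the actual utils.get_recursively):

    from typing import Any, Callable, Sequence

    def get_recursively(
        x: Any,
        path: Sequence[str],
        *,
        value_fn: Callable[[str, Any], Any] = lambda k, v: v[k],
    ) -> Any:
        """Walks path, delegating each lookup step to value_fn."""
        for key in path:
            x = value_fn(key, x)
        return x

    nested = {"model": {"decoder": {"eps": 1e-5}}}
    assert get_recursively(nested, ["model", "decoder", "eps"]) == 1e-5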

@apoorvtintin (Contributor, Author) commented on Jan 15, 2025:

Added a more flexible PartitionSpecModifier that can modify multiple partition_spec attributes in a single module config.
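
A hypothetical usage sketch (module paths and attribute names invented for illustration): the modifier is keyed by module path, and each entry can set several partition_spec attributes at once:

    # Hypothetical keys and attribute names, for illustration only.
    partition_spec_modifications = {
        "model.decoder.emb.token_emb": {
            "param_partition_spec": ("model", None),  # vocab-parallel weight
        },
        "model.decoder.transformer.layer.self_attention": {
            "input_partition_spec": (("data", "fsdp"), None, None),
        },
    }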

@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 45c7df1 to 8807856 on January 17, 2025 01:17
@kelvin-zou (Contributor) left a comment:

Mostly LGTM; some minor comments.

axlearn/common/trainer_config_modifier.py (2 review threads: outdated, resolved)
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from 8807856 to 25510d6 on January 22, 2025 01:39
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from eec33eb to 86bafa8 on January 23, 2025 05:40
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch 2 times, most recently from 4661492 to fe96240 on January 23, 2025 16:38
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from fe96240 to da90757 on January 24, 2025 23:32
@apoorvtintin requested a review from ruomingp on January 24, 2025 23:40
@apoorvtintin force-pushed the mainline-upstream-boilerplate branch from da90757 to 7e2e5f2 on January 27, 2025 10:07
@apoorvtintin (Contributor, Author) commented on Jan 27, 2025:

@ruomingp and @kelvin-zou, thank you both for the review. I addressed all comments; please let me know if any more changes are needed. The PR looks clean now.

key: str

def recursive_traverse(self, key_path: Sequence[str]) -> tuple[Any, str]:
"""Recursively traverse the config to find the target key.
Contributor:

Please see other comment re get_recursively; also, I wonder whether we actually need recursion here (seems like a loop would be simpler).
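
The loop-based form the reviewer suggests might look roughly like this (a sketch, not the PR's code): track the parent while walking, so the caller gets both the parent object and the final key:

    from typing import Any, NamedTuple, Sequence

    class TraverseResult(NamedTuple):
        parent: Any
        key: str

    def traverse(cfg: Any, key_path: Sequence[str]) -> TraverseResult:
        """Iterative traversal; assumes key_path is non-empty."""
        parent = cfg
        for key in key_path[:-1]:
            if not hasattr(parent, key):
                raise ValueError(f"{key} is not found in {parent}.")
            parent = getattr(parent, key)
        return TraverseResult(parent=parent, key=key_path[-1])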

@@ -146,6 +137,110 @@ def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
return cfg


class ModelConfigModifier(ConfigModifier):
Contributor:

Which part of this class is specific to model? It seems to take generic modifications?

"""Configure ModelConfigModifier.

Attributes:
model_cfg_modifications: A mapping from module path
Contributor:

Outdated?

"""Merge configurations from the config being replaced on a best effort basis.

Merge Rules:
- Klass is not changed, use target cfg
Contributor:

Suggested change:
-    - Klass is not changed, use target cfg
+    - Klass is not changed, use target cfg.

Please end all sentences with punctuation.

Comment on lines +171 to +172
target_cfg: configuration that will replace found_module.
found_module: existing configuration whose class will be replaced
Contributor:

Suggested change:
-    target_cfg: configuration that will replace found_module.
-    found_module: existing configuration whose class will be replaced
+    target_cfg: Configuration that will replace found_module.
+    found_module: Existing configuration whose class will be replaced

if version != Version.V1:
trn2_model_modifications.append(
ModelConfigModifier.default_config().set(
target_config="model.decoder.transformer.layer.self_attention.attention."
Contributor:

A downside of representing these deeply nested configs as string paths is that they are brittle, and can quickly become outdated.

Have we considered using cfg.visit to achieve some of these modifications? E.g.:

    def set_layer_norm_eps_recursively(cfg: ConfigBase, eps: float, set_only_if_none: bool = False)

(A bit late to review, so apologies if this discussion has already taken place.)
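
For illustration, a visit-based setter might look like the sketch below; the visit helper here is a stand-in, and the real cfg.visit API may differ:

    from types import SimpleNamespace

    def visit(cfg, fn):
        """Calls fn on cfg and, recursively, on every config-like attribute."""
        fn(cfg)
        for value in vars(cfg).values():
            if hasattr(value, "__dict__"):
                visit(value, fn)

    def set_layer_norm_eps_recursively(cfg, eps, set_only_if_none=False):
        def maybe_set(c):
            if hasattr(c, "eps") and (not set_only_if_none or c.eps is None):
                c.eps = eps
        visit(cfg, maybe_set)

    cfg = SimpleNamespace(norm=SimpleNamespace(eps=None))
    set_layer_norm_eps_recursively(cfg, 1e-5, set_only_if_none=True)
    assert cfg.norm.eps == 1e-5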

# The key string.
key: str

def recursive_traverse(self, key_path: Sequence[str]) -> tuple[Any, str]:
Contributor:

Does this need to be a public method?

traverse_result = self.recursive_traverse(key_path)
return getattr(traverse_result.parent, traverse_result.key)

def set_recursively(self, key_path: Sequence[str], new_value: Any):
Contributor:

Please name consistently with:

    def set_recursively(
        x: NestedTensor,
        *,
        value: Tensor,

Suggested change:
-    def set_recursively(self, key_path: Sequence[str], new_value: Any):
+    def set_recursively(self, path: Sequence[str], *, value: Any):

value = getattr(self, target_key)
return value.recursive_traverse(key_path[1:])

def get_recursively(self, key_path: Sequence[str]) -> Any:
Contributor:

Suggested change:
-    def get_recursively(self, key_path: Sequence[str]) -> Any:
+    def get_recursively(self, path: Sequence[str]) -> Any:

"""Recursively find the target key in the config and return its value.

Args:
key_path: A sequence of keys for indexing to get the target value.
Contributor:

Can path be empty? Maybe it can return self if path is empty?

"""Recursively find the target key in the config and set its value.

Args:
key_path: A sequence of keys for indexing to set the target value.
Contributor:

Can path be empty?

Raises:
ValueError: A key in key_path is not found.
"""
traverse_result = self.recursive_traverse(key_path)
Contributor:

Can we do something like:

Suggested change:
-    traverse_result = self.recursive_traverse(key_path)
+    if not path:
+        raise ValueError(...)
+    parent = self.get_recursively(path[:-1])